{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Pytorch Alexnet Example\n",
    "=======================\n",
    "\n",
    "This is a complete example of training an alexnet on pytorch, fully within notebook, and using nothing but widely-used library functions.\n",
    "\n",
    "Warning: this notebook download a full large-scale dataset (places365).  That is too large to do in a practical way on Google Colab, so you need to host this notebook on your own server."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch, torchvision, os\n",
    "\n",
    "def train_alexnet_places(num_steps=100000):\n",
    "    print(\"Making alexnet...\")\n",
    "    alexnet = make_untrained_alexnet_places()\n",
    "    alexnet.train()\n",
    "    print(\"Loading datasets...\")\n",
    "    train_loader, val_loader = get_train_and_val_data_loaders()\n",
    "    print(\"Training classifier...\")\n",
    "    checkpointer = make_checkpointing_function(val_loader, checkpoint_dir='checkpoints')\n",
    "    train_classifier(alexnet, train_loader,\n",
    "                     max_iter=num_steps,\n",
    "                     momentum=0.9,\n",
    "                     init_lr=2e-2,\n",
    "                     weight_decay=5e-4,\n",
    "                     monitor=checkpointer)\n",
    "    return alexnet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Untrained Alexnet\n",
    "-----------------\n",
    "\n",
    "This function creates an untrained alexnet, with randomized parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "from collections import OrderedDict\n",
    "def make_untrained_alexnet_places():\n",
    "    # channel widths\n",
    "    w = [3, 96, 256, 384, 384, 256, 4096, 4096, 365]\n",
    "    # Alexnet splits channels into groups\n",
    "    groups = [1, 2, 1, 2, 2]\n",
    "    model = nn.Sequential(OrderedDict([\n",
    "        ('conv1', nn.Conv2d(w[0], w[1], kernel_size=11,\n",
    "            stride=4,\n",
    "            groups=groups[0], bias=True)),\n",
    "        ('relu1', nn.ReLU(inplace=True)),\n",
    "        ('pool1', nn.MaxPool2d(kernel_size=3, stride=2)),\n",
    "        ('conv2', nn.Conv2d(w[1], w[2], kernel_size=5, padding=2,\n",
    "            groups=groups[1], bias=True)),\n",
    "        ('relu2', nn.ReLU(inplace=True)),\n",
    "        ('pool2', nn.MaxPool2d(kernel_size=3, stride=2)),\n",
    "        ('conv3', nn.Conv2d(w[2], w[3], kernel_size=3, padding=1,\n",
    "            groups=groups[2], bias=True)),\n",
    "        ('relu3', nn.ReLU(inplace=True)),\n",
    "        ('conv4', nn.Conv2d(w[3], w[4], kernel_size=3, padding=1,\n",
    "            groups=groups[3], bias=True)),\n",
    "        ('relu4', nn.ReLU(inplace=True)),\n",
    "        ('conv5', nn.Conv2d(w[4], w[5], kernel_size=3, padding=1,\n",
    "            groups=groups[4], bias=True)),\n",
    "        ('relu5', nn.ReLU(inplace=True)),\n",
    "        ('pool5', nn.MaxPool2d(kernel_size=3, stride=2)),\n",
    "        ('flatten', nn.Flatten()),\n",
    "        ('fc6', nn.Linear(w[5] * 6 * 6, w[6], bias=True)),\n",
    "        ('relu6', nn.ReLU(inplace=True)),\n",
    "        ('dropout6', nn.Dropout()),\n",
    "        ('fc7', nn.Linear(w[6], w[7], bias=True)),\n",
    "        ('relu7', nn.ReLU(inplace=True)),\n",
    "        ('dropout7', nn.Dropout()),\n",
    "        ('fc8', nn.Linear(w[7], w[8]))\n",
    "    ]))\n",
    "    # Setup the initial parameters randomly\n",
    "    for n, p in model.named_parameters():\n",
    "        if 'bias' in n:\n",
    "            torch.nn.init.zeros_(p)\n",
    "        else:\n",
    "            torch.nn.init.kaiming_normal_(p, nonlinearity='relu')\n",
    "    model.cuda()\n",
    "    model.train()\n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can call the function to make a network, and then list all the network's trainable parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = make_untrained_alexnet_places()\n",
    "for n, p in a.named_parameters():\n",
    "    print(n, tuple(p.shape))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And we can save the uninitialized neural network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(a.state_dict(), 'checkpoints/uninitialized_alexnet.pth')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Main Training Loop\n",
    "------------------\n",
    "\n",
    "This is a generic training loop for a classifier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_classifier(model, train_data_loader, max_iter,\n",
    "                     momentum=0.9, init_lr=2e-2, weight_decay=5e-4,\n",
    "                     monitor=None):\n",
    "    if monitor is not None:\n",
    "        monitor(model, 0, 0.0, 0.0, 0)\n",
    "    optimizer = torch.optim.SGD(\n",
    "        model.parameters(),\n",
    "        lr=init_lr, momentum=momentum, weight_decay=weight_decay)\n",
    "    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, init_lr, max_iter)\n",
    "    iter_num = 0\n",
    "    while iter_num < max_iter:\n",
    "        for t_input, t_target in train_data_loader:\n",
    "            # Copy data into the gpu\n",
    "            input_var, target_var = [d.cuda() for d in [t_input, t_target]]\n",
    "            # Evaluate model\n",
    "            output = model(input_var)\n",
    "            loss = torch.nn.functional.cross_entropy(output, target_var)\n",
    "            # Perform one step of SGD\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            scheduler.step() # Learning rate schedule\n",
    "            # Check training set accuracy\n",
    "            _, pred = output.max(1)\n",
    "            batch_size = len(t_input)\n",
    "            accuracy = target_var.detach().eq(pred).float().sum().item() / batch_size\n",
    "            # Advance, and print out some stats\n",
    "            iter_num += 1\n",
    "            if monitor is not None:\n",
    "                monitor(model, iter_num, loss, accuracy, batch_size)\n",
    "            if iter_num >= max_iter:\n",
    "                break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Data set\n",
    "--------\n",
    "\n",
    "This is the definition of the places data set used for training.\n",
    "If we do not have the files, we download them.  And then we make a\n",
    "DataSet object that defines how to resize, crop, and normalize the images.\n",
    "\n",
    "The DataLoader objects wrap the dataset in a multithreaded streaming\n",
    "object that batches the image data and loads it quickly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_places_data_set(split, crop_size=227, download=True):\n",
    "    dirname = f'datasets/places/{split}'\n",
    "    nfs_source = '/data/vision/torralba/datasets/places/files'\n",
    "    web_source = 'https://dissect.csail.mit.edu/datasets/'\n",
    "    if not os.path.exists(dirname) and download:\n",
    "        if os.path.exists(nfs_source):\n",
    "            os.symlink(nfs_source, 'datasets/places')\n",
    "        else:\n",
    "            os.makedirs(dirname, exist_ok=True)\n",
    "            torchvision.datasets.utils.download_and_extract_archive(\n",
    "                'web_sources' +\n",
    "                'places_%s.zip' % split,\n",
    "                'datasets',\n",
    "                md5=dict(val='593bbc21590cf7c396faac2e600cd30c',\n",
    "                         train='d1db6ad3fc1d69b94da325ac08886a01')[split])\n",
    "    if split == 'train':\n",
    "        cropping_rule = [\n",
    "            torchvision.transforms.RandomCrop(227),\n",
    "            torchvision.transforms.RandomHorizontalFlip() ]\n",
    "    else:\n",
    "        cropping_rule = [torchvision.transforms.CenterCrop(crop_size)]\n",
    "    places_transform = torchvision.transforms.Compose([\n",
    "        torchvision.transforms.Resize(256)\n",
    "        ] + cropping_rule + [\n",
    "        torchvision.transforms.ToTensor(),\n",
    "        torchvision.transforms.Normalize(\n",
    "            [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "    ])\n",
    "    return torchvision.datasets.ImageFolder(\n",
    "        dirname, transform=places_transform)\n",
    "\n",
    "def get_train_and_val_data_loaders():\n",
    "    return [\n",
    "        torch.utils.data.DataLoader(\n",
    "            get_places_data_set(split),\n",
    "            batch_size=256, shuffle=(split == 'train'),\n",
    "            num_workers=48, pin_memory=True)\n",
    "        for split in ['train', 'val']\n",
    "    ]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Generic Evaluation and Checkpointing Utilities\n",
    "----------------------------------------------\n",
    "\n",
    " * **measure_val_accuracy_and_loss** evaluates the model on the holdout set and reports its performance.\n",
    " * **save_model_iteration** saves the current model parameters in a pytorch file.\n",
    " * **make_training_monitor** makes a callback function for periodically evaluating and saving a model during training.\n",
    " * **AverageMeter** tracks averages (e.g., average accuracy, average loss)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def measure_val_accuracy_and_loss(model, val_data_loader):\n",
    "    '''\n",
    "    Evaluates the model (in inference mode) on holdout data.\n",
    "    '''\n",
    "    model.eval()\n",
    "    val_loss, val_acc = AverageMeter(), AverageMeter()\n",
    "    for input, target in val_data_loader:\n",
    "        input_var, target_var = [d.cuda() for d in [input, target]]\n",
    "        with torch.no_grad():\n",
    "            output = model(input_var)\n",
    "            loss = torch.nn.functional.cross_entropy(output, target_var)\n",
    "            _, pred = output.max(1)\n",
    "            accuracy = (target_var.eq(pred)\n",
    "                    ).data.float().sum().item() / input.size(0)\n",
    "        val_acc.update(accuracy, input.size(0))\n",
    "        val_loss.update(loss.data.item(), input.size(0))\n",
    "    return val_acc, val_loss\n",
    "\n",
    "def save_model_iteration(model, iter_num, checkpoint_dir):\n",
    "    '''\n",
    "    Saves the current parameters of the model to a file.\n",
    "    '''\n",
    "    torch.save(model.state_dict(), os.path.join(checkpoint_dir, f'iter_{iter_num}.pth'))\n",
    "        \n",
    "def make_checkpointing_function(val_data_loader, checkpoint_dir=None, checkpoint_freq=100):\n",
    "    '''\n",
    "    Makes a callback to monitor training and make checkpoints.\n",
    "    '''\n",
    "    avg_train_accuracy, avg_train_loss = AverageMeter(), AverageMeter()\n",
    "    def monitor(model, iter_num, loss, accuracy, batch_size):\n",
    "        avg_train_accuracy.update(accuracy, batch_size)\n",
    "        avg_train_loss.update(loss, batch_size)\n",
    "        if iter_num % checkpoint_freq == 0:\n",
    "            val_accuracy, val_loss = measure_val_accuracy_and_loss(model, val_data_loader)\n",
    "            if checkpoint_dir is not None:\n",
    "                save_model_iteration(model, iter_num, checkpoint_dir)\n",
    "            print(f'Iter {iter_num}, ' + \n",
    "                  f'train acc {avg_train_accuracy.avg:.3g} loss {avg_train_loss.avg:.3g}, ' +\n",
    "                  f'val acc {val_accuracy.avg:.3g}, loss {val_loss.avg:.3g}')\n",
    "            model.train()\n",
    "    return monitor            \n",
    "            \n",
    "class AverageMeter(object):\n",
    "    '''\n",
    "    To keep running averages.\n",
    "    '''\n",
    "    def __init__(self):\n",
    "        self.reset()\n",
    "    def reset(self):\n",
    "        self.val = 0.\n",
    "        self.avg = 0.\n",
    "        self.sum = 0.\n",
    "        self.count = 0\n",
    "    def update(self, val, n=1):\n",
    "        self.val = val\n",
    "        self.sum += val * n\n",
    "        self.count += n\n",
    "        if self.count:\n",
    "            self.avg = self.sum / self.count"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now do the work\n",
    "---------------\n",
    "\n",
    "Try loading alexnet from a checkpoint.  If we have not yet saved a checkpoint snapshot with the number of iterations we want, then train it. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_iterations = 100\n",
    "try:\n",
    "    a = make_untrained_alexnet_places()\n",
    "    a.load_state_dict(torch.load(f'checkpoints/iter_{num_iterations}.pth'))\n",
    "except:\n",
    "    a = train_alexnet_places(num_iterations)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now view one image - reverse the dataset normalization to get a nice image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dsv = get_places_data_set('val')\n",
    "im, label = dsv[5000]\n",
    "im = im.cuda()\n",
    "# Reverse the normalization\n",
    "unnormalized = (im.cpu().permute(1, 2, 0)\n",
    "    * torch.tensor([0.229, 0.224, 0.225])\n",
    "    + torch.tensor([0.485, 0.456, 0.406]))\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "plt.imshow(unnormalized)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, run the network on the function and print the prediction.\n",
    "\n",
    "Note the network expexts to work in batches, so `im[None]` forms an image batch of size one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a.eval()\n",
    "output = a(im[None])\n",
    "pred = output.max(1)[1][0]\n",
    "\n",
    "print('prediction: ', dsv.classes[pred])\n",
    "print('groundtruth: ', dsv.classes[label])"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}